-- Copyright 2022 Google LLC
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

module OccAnalysisSpec (spec) where

import Prelude hiding (unlines)
import Data.Maybe (catMaybes)
import Data.Text
import Test.Hspec

import ConcreteSyntax (parseUModule)
import AbstractSyntax (parseExpr)
import Err
import Inference (inferTopUExpr, synthTopE)
import Name
import OccAnalysis
import Occurrence
import Simplify
import SourceRename (renameSourceNamesUExpr)
import Types.Core
import Types.Imp (Backend (..))
import Types.Primitives
import Types.Source
import TopLevel
import QueryType

-- Parses `source` (passing `Main` to `parseUModule`), ensures every dependency
-- reported by the parser is loaded, and then lowers each parsed source block
-- with `sourceBlockToBlock`.  Source blocks that lower to `Nothing` are
-- dropped from the result.
sourceTextToBlocks :: (Topper m, Mut n) => Text -> m n [SBlock n]
sourceTextToBlocks source = do
  let UModule _ moduleDeps parsedBlocks = parseUModule Main source
  mapM_ ensureModuleLoaded moduleDeps
  loweredBlocks <- mapM sourceBlockToBlock parsedBlocks
  return $ catMaybes loweredBlocks

-- Lowers one source block for the purposes of these tests:
--   * a module import is executed and yields no block (`Nothing`);
--   * an `EvalExpr (Printed _)` command has its expression parsed and lowered
--     with `uExprToBlock`;
--   * an unparseable block is rethrown as a `ParseErr`;
--   * any other kind of source block is a bug in the test input, reported via
--     `error`.
sourceBlockToBlock :: (Topper m, Mut n) => SourceBlock -> m n (Maybe (SBlock n))
sourceBlockToBlock block = case sbContents block of
  Misc (ImportModule moduleName) -> Nothing <$ importModule moduleName
  Command (EvalExpr (Printed _)) expr ->
    fmap Just $ uExprToBlock =<< parseExpr expr
  UnParseable _ msg -> throw ParseErr msg
  _ -> error $ "Unexpected SourceBlock " ++ pprint block ++ " in unit tests"

-- Runs a parsed expression through the front-end passes in order: source
-- renaming, type inference, synthesis, and top-level simplification.
--
-- The pattern bind accepts only a simplified top-level lambda with no binders
-- (`Empty`) paired with a `CoerceRecon` reconstruction; the body of that
-- lambda is the returned block.  Any other result is a pattern-match failure
-- in the surrounding monad.
uExprToBlock :: (Topper m, Mut n) => UExpr 'VoidS -> m n (SBlock n)
uExprToBlock expr = do
  SimplifiedTopLam (TopLam _ _ (LamExpr Empty block)) (CoerceRecon _) <-
    renameSourceNamesUExpr expr
      >>= inferTopUExpr
      >>= synthTopE
      >>= simplifyTopBlock
  return block

-- Returns the annotation attached to the first `Let` declaration in the block
-- whose bound expression is a `RunIO` higher-order primitive.  Calls `error`
-- if the block's declarations contain no such binding.
findRunIOAnnotation :: SBlock n -> LetAnn
findRunIOAnnotation (Abs decls _) = firstRunIO decls where
  -- Walks the declarations front to back and stops at the first match.
  firstRunIO :: Nest SDecl n l -> LetAnn
  firstRunIO nest = case nest of
    Empty -> error "RunIO not found"
    Nest (Let _ (DeclBinding ann (PrimOp (Hof (TypedHof _ (RunIO _)))))) _ -> ann
    Nest _ rest -> firstRunIO rest

-- Lowers the given source lines (joined with `unlines`) to a single block,
-- runs occurrence analysis on it, and returns the annotation that the
-- analysis left on the block's `RunIO` binding.
--
-- The source must lower to exactly one block; anything else is a
-- pattern-match failure inside the `runTopperM` action.  The updated
-- top-level state produced by `runTopperM` is discarded.
analyze :: EvalConfig -> TopStateEx -> [Text] -> IO LetAnn
analyze cfg env code = fmap fst $ runTopperM cfg env do
  [block] <- sourceTextToBlocks (unlines code)
  TopLam _ _ (LamExpr Empty analyzed) <-
    analyzeOccurrences =<< asTopLam (LamExpr Empty block)
  -- The examples in this file get their RunIO from simplifying
  -- `unreachable()`.  Compound examples containing more than one RunIO block
  -- would need more precise pattern-matching than "the first one".
  return (findRunIOAnnotation analyzed)

-- Hspec suite for occurrence analysis.  Every example analyzes a small `:p`
-- program in which `xs` (or `n`, in the last example) is bound to
-- `unreachable()`, and checks the annotation that `analyze` reports for it.
spec :: Spec
spec = do
  let evalCfg = EvalConfig LLVM [LibBuiltinPath] Nothing Nothing Nothing Optimize PrintCodegen
  -- Open question: or just use initTopState, to always compile the prelude
  -- during unit tests?
  cachedEnv <- runIO loadCache
  (_, preludeEnv) <- runIO $ runTopperM evalCfg cachedEnv $ ensureModuleLoaded Prelude
  -- All examples share the same configuration and Prelude-loaded environment.
  let occAnnotationOf = analyze evalCfg preludeEnv
  describe "Occurrence analysis" do
    it "counts a reference as a use" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  xs"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (0, One))
    it "counts indexing in a for as one use" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  for i. xs[i]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, One))
    it "counts indexing depth in nested fors" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => (Fin 3) => Float = unreachable()"
        , "  for i j. xs[i, j]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (2, One))
    it "counts two array uses" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  (for i. xs[i], for j. xs[j])"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (1, Two))
    it "counts array and non-array uses" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  (for i. xs[i], xs)"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (1, Two))
    it "counts different case arms as static but not dynamic uses" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  if unreachable()"
        , "    then for i. xs[i]"
        , "    else for j. xs[j]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (1, One))
    it "understands one index injection" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Fin 10, Fin 4) => Float = unreachable()"
        , "  for j. xs[Left j]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, One))
    it "understands distinct index injections" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Fin 10, Fin 4) => Float = unreachable()"
        , "  (for i. xs[Left i], for j. xs[Right j])"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (1, One))
    it "detects and eschews index arithmetic" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 4) => Float = unreachable()"
        , "  for i:(Fin 3). xs[(ordinal i + 1)@_]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, Unbounded))
    it "detects non-nested single-uses cases despite index arithmetic" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 4) => Float = unreachable()"
        , "  xs[1@_]"
        ]
        -- Arguably, the analysis should be able to prove that zero levels of
        -- exposed indexing (not one) suffice for inlining xs to be safe here,
        -- but the current occurrence analysis doesn't prove it yet.
        `shouldReturn` OccInfoPure (UsageInfo One (1, One))
    it "detects nested single-uses cases despite index arithmetic" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 4) => (Fin 3) => Float = unreachable()"
        , "  for i:(Fin 3). xs[(ordinal i + 1)@_, i]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (2, One))
    it "detects repeated access" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 4) => Float = unreachable()"
        , "  for i j:(Fin 5). xs[i]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, Unbounded))
    it "does not count the `trace` pattern as repeated access" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 4) => (Fin 4) => Float = unreachable()"
        , "  for i. xs[i, i]"
        ]
        -- Arguably, the analysis should be able to prove that only one level
        -- of exposed indexing (not two) suffices for inlining xs to be safe
        -- here, but it doesn't prove it yet.
        `shouldReturn` OccInfoPure (UsageInfo One (2, One))
    it "solves safe sum-over-max" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Either(Fin 4, Fin 4), Either(Fin 4, Fin 4)) => Float = unreachable()"
        , "  ys = for i."
        , "    if unreachable()"
        , "      then xs[Left  (Left  i)]"
        , "      else xs[Left  (Right i)]"
        , "  zs = for j."
        , "    if unreachable()"
        , "      then xs[Right (Left  j)]"
        , "      else xs[Right (Right j)]"
        , "  (ys, zs)"
        ]
        `shouldReturn` OccInfoPure (UsageInfo (Bounded 4) (1, One))
    it "solves unsafe sum-over-max" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Either(Fin 4, Fin 4), Either(Fin 4, Fin 4)) => Float = unreachable()"
        , "  ys = for i."
        , "    if unreachable()"
        , "      then xs[Left  (Left  i)]"
        , "      else xs[Left  (Right i)]"
        , "  zs = for j."
        , "    if unreachable()"
        , "      then xs[Right (Left  j)]"
        , "      else xs[Left  (Left  j)]"
        , "  (ys, zs)"
        ]
        -- One of the code paths hits the same element(s) as another.
        `shouldReturn` OccInfoPure (UsageInfo (Bounded 4) (1, Two))
    it "does not penalize referring to indices in scope" do
      occAnnotationOf
        [ ":p"
        , "  j = 1@(Fin 3)"
        , "  xs : (Fin 10) => (Fin 3) => Float = unreachable()"
        , "  for i. xs[i, j]"
        ]
        -- Arguably, the analysis should be able to prove that only one level
        -- of exposed indexing (not two) suffices for inlining xs to be safe
        -- here, but it doesn't prove it yet.
        `shouldReturn` OccInfoPure (UsageInfo One (2, One))
    it "is conservative about potential collisions between indices in scope" do
      occAnnotationOf
        [ ":p"
        , "  j = 1@(Fin 3)"
        , "  k = 1@(Fin 3)"
        , "  xs : (Fin 10) => (Fin 3) => Float = unreachable()"
        , "  (for i. xs[i, j], for i. xs[i, k])"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (2, Two))
    it "is not confused by potential collisions at an early indexing depth" do
      occAnnotationOf
        [ ":p"
        , "  j = 1@(Fin 10)"
        , "  k = 1@(Fin 10)"
        , "  xs : (Fin 10) => (Fin 3) => Float = unreachable()"
        , "  (for i. xs[j, i], for i. xs[k, i])"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (2, Two))
    it "does not crash on indexing by case-bound binders" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  for i."
        , "    case i of"
        , "      (Left  j) -> xs[j]"
        , "      (Right k) -> xs[k]"
        ]
        -- TODO Actually, it should be possible to get this to be Bounded 2
        -- rather than Unbounded.  We get Unbounded because `occAlt` assumes
        -- that the binder of a case (here, `j`) is an "unknown function" of
        -- the scrutinee.  That's conservative; in reality the function is
        -- very well known, and even injective, just not total (and not the
        -- same across arms of the `case`).  However, trying to fix that would
        -- unmask another bug: mapping `case` to `max` is only correct if the
        -- scrutinee doesn't depend on any binders being iterated.  In this
        -- example it does, so both arms of the `case` end up being taken,
        -- albeit at different iterations of the `i` loop.  To analyze this
        -- correctly, we would need to know that `j` and `k` may collide
        -- across `case` arms, though not within an arm.
        `shouldReturn` OccInfoPure (UsageInfo Two (1, Unbounded))
    it "does not crash on indexing by state-effect-bound binders" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  with_state (0 @ Fin 10) \\ref."
        , "    xs[get ref]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, One))
    it "assumes state references touch everything" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  with_state (0 @ Fin 10) \\ref."
        , "    for i:(Fin 3)."
        , "      xs[get ref]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, Unbounded))
    it "assumes state references touch everything even if the initializer doesn't" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  for i:(Fin 10)."
        , "    with_state i \\ref."
        , "      xs[get ref]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, Unbounded))
    it "analyzes through accum effects" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  run_accum (AddMonoid Float) \\ref."
        , "    for i."
        , "      ref += xs[i]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, One))
    -- TODO Should probably construct an example of indexing with the ref of a
    -- run_accum (and, for that matter, with the handle of a run_state or
    -- run_accum), but it's not clear how to make either of those well-typed.
    it "does not crash on indexing by reader-effect-bound binders" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        , "  with_reader (0 @ Fin 10) \\ref."
        , "    xs[ask ref]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, One))
    it "understands that while loops access repeatedly" do
      occAnnotationOf
        [ ":p"
        , "  xs : (Fin 10) => Float = unreachable()"
        -- The with_state keeps the access from appearing dead.
        , "  with_state 0 \\ref."
        , "    while \\."
        , "      for i."
        , "        ref := (get ref) + xs[i]"
        , "      False"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, Unbounded))
    it "does not crash on index-defining bindings" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Fin 10, Fin 3) => Float = unreachable()"
        , "  for i:(Fin 10)."
        , "    j = Left(unsafe_from_ordinal(1 + ordinal i))"
        , "    xs[j]"
        ]
        `shouldReturn` OccInfoPure (UsageInfo One (1, Unbounded))
    it "understands indexing by literals" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Fin 10, Fin 3) => Float = unreachable()"
        , "  (xs[0@_], xs[0@_])"
        ]
        `shouldReturn` OccInfoPure (UsageInfo Two (1, Two))
    it "is conservative about distict literal indices" do
      occAnnotationOf
        [ ":p"
        , "  xs : Either(Fin 10, Fin 3) => Float = unreachable()"
        , "  (xs[0@_], xs[1@_])"
        ]
        -- TODO In this case, we should be able to detect that indexing by 0
        -- and by 1 cannot collide; but assuming they may collide is safe.
        `shouldReturn` OccInfoPure (UsageInfo Two (1, Two))
    it "is conservative about ix dicts" do
      occAnnotationOf
        [ ":p"
        , "  n : Nat = unreachable()"
        , "  xs : (Fin n) => Nat = iota (Fin n)"
        , "  sum xs"
        ]
        -- We consider `n` statically unbounded because we assume that its ix
        -- dict may be inlined uncontrollably later.
        `shouldReturn` OccInfoPure (UsageInfo Unbounded (0, Unbounded))
